{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GlueGen Stable Diffusion Pipeline\n",
    "\n",
    "GlueGen is a minimal adapter that aligns any encoder (a text encoder for a different language, multilingual RoBERTa, AudioCLIP) with the CLIP text encoder used in the standard Stable Diffusion model. This allows easy language adaptation of existing English Stable Diffusion checkpoints without requiring an image-captioning dataset or long training hours.\n",
    "\n",
    "Make sure you have downloaded `gluenet_French_clip_overnorm_over3_noln.ckpt` for French from [GlueGen's official repo](https://github.com/salesforce/GlueGen/tree/main); pre-trained weights are also available for Chinese, Italian, Japanese, and Spanish, or you can train your own. This script was contributed by [Phạm Hồng Vinh](https://github.com/rootonchair) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.32.2)\n",
      "Requirement already satisfied: transformers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (4.48.3)\n",
      "Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.6.0)\n",
      "Requirement already satisfied: pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (9.0.0)\n",
      "Requirement already satisfied: sentencepiece in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.2.0)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.6.1)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.17.0)\n",
      "Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.28.1)\n",
      "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2.32.3)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.5.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (24.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (6.0.2)\n",
      "Requirement already satisfied: tokenizers<0.22,>=0.21 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.21.0)\n",
      "Requirement already satisfied: tqdm>=4.27 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (4.67.1)\n",
      "Requirement already satisfied: typing-extensions>=4.10.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.5)\n",
      "Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2025.2.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.5.8)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.2.1.3)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (10.3.5.147)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.6.1.9)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.3.1.170)\n",
      "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (0.6.2)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.4.127)\n",
      "Requirement already satisfied: triton==3.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.2.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.4.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2.3.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2025.1.31)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install diffusers transformers torch pillow sentencepiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt\n",
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt\n",
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt\n",
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt\n",
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt\n",
      "Downloading ./checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt...\n",
      "Downloaded ./checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt\n",
      "Using device: cuda\n",
      "Initializing XLM-RoBERTa...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "- This IS expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing CLIP models...\n",
      "Initializing Diffusion Pipeline...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Expected types for tokenizer: ['AutoTokenizer'], got CLIPTokenizer.\n",
      "Expected types for text_encoder: ['AutoModel'], got CLIPTextModel.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a999a4258f5438f9ea291212b5df004",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Processing French model...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48f54a321b1f4af1a4fd5b26c0a2e441",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image saved to outputs/gluegen_output_french.png\n",
      "\n",
      "Processing complete.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import gc\n",
    "import urllib.request\n",
    "import torch\n",
    "from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM, CLIPTokenizer, CLIPTextModel\n",
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "# Download checkpoints\n",
    "CHECKPOINTS = [\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Chinese_clip_overnorm_over3_noln.ckpt\",\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_French_clip_overnorm_over3_noln.ckpt\",\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Italian_clip_overnorm_over3_noln.ckpt\",\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Japanese_clip_overnorm_over3_noln.ckpt\",\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_Spanish_clip_overnorm_over3_noln.ckpt\",\n",
    "    \"https://storage.googleapis.com/sfr-gluegen-data-research/checkpoints_all/gluenet_checkpoint/gluenet_sound2img_audioclip_us8k.ckpt\"\n",
    "]\n",
    "\n",
    "LANGUAGE_PROMPTS = {\n",
    "    \"French\": \"une voiture sur la plage\",\n",
    "    #\"Chinese\": \"海滩上的一辆车\",\n",
    "    #\"Italian\": \"una macchina sulla spiaggia\",\n",
    "    #\"Japanese\": \"浜辺の車\",\n",
    "    #\"Spanish\": \"un coche en la playa\"\n",
    "}\n",
    "\n",
    "def download_checkpoints(checkpoint_dir):\n",
    "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "    for url in CHECKPOINTS:\n",
    "        filename = os.path.join(checkpoint_dir, os.path.basename(url))\n",
    "        if not os.path.exists(filename):\n",
    "            print(f\"Downloading {filename}...\")\n",
    "            urllib.request.urlretrieve(url, filename)\n",
    "            print(f\"Downloaded {filename}\")\n",
    "        else:\n",
    "            print(f\"Checkpoint {filename} already exists, skipping download.\")\n",
    "    return checkpoint_dir\n",
    "\n",
    "def load_checkpoint(pipeline, checkpoint_path, device):\n",
    "    state_dict = torch.load(checkpoint_path, map_location=device)\n",
    "    state_dict = state_dict.get(\"state_dict\", state_dict)\n",
    "    missing_keys, unexpected_keys = pipeline.unet.load_state_dict(state_dict, strict=False)\n",
    "    return pipeline\n",
    "\n",
    "def generate_image(pipeline, prompt, device, output_path):\n",
    "    with torch.inference_mode():\n",
    "        image = pipeline(\n",
    "            prompt,\n",
    "            generator=torch.Generator(device=device).manual_seed(42),\n",
    "            num_inference_steps=50\n",
    "        ).images[0]\n",
    "        image.save(output_path)\n",
    "        print(f\"Image saved to {output_path}\")\n",
    "\n",
    "checkpoint_dir = download_checkpoints(\"./checkpoints_all/gluenet_checkpoint\")\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Smoke-test XLM-RoBERTa on a single sentence. The output is discarded, and neither\n",
    "# `tokenizer` nor `model` is handed to the diffusion pipeline below.\n",
    "print(\"Initializing XLM-RoBERTa...\")\n",
    "tokenizer = XLMRobertaTokenizer.from_pretrained(\"xlm-roberta-base\", use_fast=False)\n",
    "model = XLMRobertaForMaskedLM.from_pretrained(\"xlm-roberta-base\").to(device)\n",
    "# Insert the tokenizer's own mask token rather than a hard-coded \"[MASK]\" literal,\n",
    "# which is only a mask for tokenizers that define it as one.\n",
    "inputs = tokenizer(f\"Ceci est une phrase incomplète avec un {tokenizer.mask_token}.\", return_tensors=\"pt\").to(device)\n",
    "with torch.inference_mode():\n",
    "    _ = model(**inputs)\n",
    "\n",
    "\n",
    "print(\"Initializing CLIP models...\")\n",
    "clip_tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
    "clip_text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device)\n",
    "\n",
    "# Initialize pipeline\n",
    "# NOTE(review): the recorded run of this cell warns that the pipeline expected\n",
    "# AutoTokenizer/AutoModel but received the CLIP classes, and load_checkpoint() pushes the\n",
    "# GlueNet file into pipeline.unet with strict=False. Presumably the \"gluegen\" pipeline is\n",
    "# meant to receive the multilingual encoder and load the adapter through its own loader\n",
    "# (TODO confirm against the diffusers community-pipeline documentation).\n",
    "print(\"Initializing Diffusion Pipeline...\")\n",
    "pipeline = DiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n",
    "    text_encoder=clip_text_encoder,\n",
    "    tokenizer=clip_tokenizer,\n",
    "    custom_pipeline=\"gluegen\",\n",
    "    safety_checker=None\n",
    ").to(device)\n",
    "\n",
    "os.makedirs(\"outputs\", exist_ok=True)\n",
    "\n",
    "# Generate one image per enabled language.\n",
    "for language, prompt in LANGUAGE_PROMPTS.items():\n",
    "    print(f\"\\nProcessing {language} model...\")\n",
    "\n",
    "    checkpoint_file = f\"gluenet_{language}_clip_overnorm_over3_noln.ckpt\"\n",
    "    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_file)\n",
    "    try:\n",
    "        pipeline = load_checkpoint(pipeline, checkpoint_path, device)\n",
    "        output_path = f\"outputs/gluegen_output_{language.lower()}.png\"\n",
    "        generate_image(pipeline, prompt, device, output_path)\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failure and move on to the next language.\n",
    "        print(f\"Error processing {language} model: {e}\")\n",
    "    finally:\n",
    "        # Runs after failed attempts too, so memory held by a failed (e.g. out-of-memory)\n",
    "        # run is released before the next language is processed.\n",
    "        gc.collect()\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "print(\"\\nProcessing complete.\")\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
